import jax.nn as nn
import jax
import jax.numpy as jnp

from functools import partial
from flax import linen as nn
from src.models.transformers import eluplus1
from src.models.transformers import TruncatedLinearTransformer, OnlineUniversalLinearTransformerRTRL
from src.models.rnn import TruncatedVanillaRNN
from src.utils import *
from typing import NamedTuple, Optional,Any,Sequence

class TruncatedLinearTransformerPredictor(nn.Module):
    """Scalar regression head on top of a ``TruncatedLinearTransformer``.

    The transformer encodes ``inputs``. Its last element along the leading
    axis (``encoded[-1]``, presumably the most recent step — TODO confirm)
    goes through a Dense -> ReLU -> Dense(1) MLP.

    Attributes:
        d_model: model width forwarded to the transformer.
        n_heads: number of attention heads.
        d_ffc: feed-forward width inside each transformer layer.
        n_layers: number of transformer layers.
        output_hidden_size: hidden width of the prediction MLP.
        truncation: truncation length forwarded to the transformer.
        kernel_phi: feature map used by the linear attention.
    """
    d_model: int
    n_heads: int
    d_ffc: int
    n_layers: int
    output_hidden_size: int
    truncation: int
    kernel_phi: Any = nn.elu

    @nn.compact
    def __call__(self, inputs):
        """Encode ``inputs`` and return a single-unit prediction."""
        encoder = TruncatedLinearTransformer(
            d_model=self.d_model,
            d_ffc=self.d_ffc,
            n_heads=self.n_heads,
            truncation=self.truncation,
            n_layers=self.n_layers,
            kernel_phi=self.kernel_phi,
        )
        encoded = encoder(inputs)
        last_step = encoded[-1]
        # Head is built after the encoder, as before, so parameter names are unchanged.
        head = nn.Sequential([
            nn.Dense(self.output_hidden_size),
            jax.nn.relu,
            nn.Dense(1),
        ])
        return head(last_step)

class OLTTBBPTPredictor(nn.Module):
    """Scalar predictor built on an online linear transformer over a sliding window.

    Two entries are kept in the ``'state'`` collection:

    * ``inputs_concat``: the last ``truncation`` input rows, initialised to
      0.1 everywhere. Each call drops the oldest row and appends the new one.
    * ``memory``: the transformer's carried memory, initialised via
      ``initialize_memory``. After each step, entry 0 of the returned memory
      (``tree_index(new_memory, 0)``) is stored back.

    The last transformer output goes through a Dense -> ReLU -> Dense(1) head.
    Per the original description, this is meant to be trained with truncated
    BPTT over the window.

    Attributes:
        n_layers: number of transformer layers.
        d_model: model width.
        d_ffc: feed-forward width inside each layer.
        n_heads: number of attention heads.
        kernel_dim: dimension of the attention kernel feature map.
        truncation: number of past inputs kept in the sliding window.
        output_hidden_size: hidden width of the prediction MLP.
        kernel_phi: feature map used by the linear attention.
        use_recency_bias: forwarded to the transformer.
    """
    n_layers: int
    d_model: int
    d_ffc: int
    n_heads: int
    kernel_dim: int
    truncation: int
    output_hidden_size: int
    kernel_phi: Any = eluplus1
    use_recency_bias: bool = True

    @nn.compact
    def __call__(self, inputs):
        """Online linear transformer with full context, trained with truncated BPTT.

        Args:
            inputs: one observation. It is flattened into a single row of
                length ``inputs.shape[-1]`` (expected shape ``(input_dim,)``).

        Returns:
            jax.numpy.array: the head output for the last transformer output,
            expected shape ``(1,)``.
        """
        feature_dim = inputs.shape[-1]
        window = self.variable('state', 'inputs_concat', jnp.full,
                               (self.truncation, feature_dim), 0.1)
        newest_row = inputs.reshape(1, -1)
        # Slide the window: drop the oldest row and append the newest one.
        window.value = jnp.concatenate([window.value[1:], newest_row], axis=0)

        encoder = OnlineUniversalLinearTransformerRTRL(
            n_layers=self.n_layers,
            d_model=self.d_model,
            d_ffc=self.d_ffc,
            n_heads=self.n_heads,
            kernel_dim=self.kernel_dim,
            kernel_phi=self.kernel_phi,
            use_recency_bias=self.use_recency_bias,
        )
        carried = self.variable('state', 'memory', encoder.initialize_memory)
        encoded, updated_memory = encoder(window.value, carried.value)
        # Keep only entry 0 of the returned memory as the carried state.
        carried.value = tree_index(updated_memory, 0)

        head = nn.Sequential([
            nn.Dense(self.output_hidden_size),
            jax.nn.relu,
            nn.Dense(1),
        ])
        return head(encoded[-1])

    
        

class OLTRTRLPredictor:
    """Online linear-transformer predictor whose gradients come from ``step_and_grad``.

    This is a plain class, not a Flax module. It owns an
    ``OnlineUniversalLinearTransformerRTRL`` (``trf_model``) and a
    Dense -> ReLU -> Dense(1) head (``pred``), and threads params, non-param
    state and memory explicitly. ``__call__`` is jitted with ``self`` marked
    static, so each instance is hashed by identity and traced separately.
    """

    def __init__(self, n_layers, d_model, d_ffc, n_heads, kernel_dim,
                 output_hidden_size, kernel_phi=nn.elu, use_recency_bias=True):
        self.output_hidden_size = output_hidden_size
        self.d_model = d_model
        self.trf_model = OnlineUniversalLinearTransformerRTRL(
            n_layers=n_layers, d_model=d_model, d_ffc=d_ffc, n_heads=n_heads,
            kernel_dim=kernel_dim, kernel_phi=kernel_phi,
            use_recency_bias=use_recency_bias)
        self.pred = nn.Sequential(
            [nn.Dense(self.output_hidden_size), jax.nn.relu, nn.Dense(1)])

    @staticmethod
    def _split_params(variables):
        """Split a Flax variables mapping into ``(non_param_collections, params)``.

        Works whether ``Module.init`` returned a ``FrozenDict`` or a plain
        ``dict``. ``FrozenDict.pop`` returns ``(rest, value)``, but ``dict.pop``
        returns only the value, so unpacking ``variables.pop('params')`` into
        two names silently yields the params' key strings on plain dicts.
        The container type of the remainder is preserved.

        Raises:
            KeyError: if ``variables`` has no ``'params'`` collection.
        """
        params = variables['params']
        rest = type(variables)(
            {name: coll for name, coll in variables.items() if name != 'params'})
        return rest, params

    def init_params_state(self, rng_key, inputs):
        """Initialise parameters, non-param state and the transformer memory.

        Args:
            rng_key: JAX PRNG key, used for both the transformer and the head init.
            inputs: a sample observation. It is reshaped to ``(1, -1)`` for the
                transformer init.

        Returns:
            ``(params, state, memory_trf)``:

            * ``params`` and ``state`` are dicts keyed by ``'trf'`` and ``'out'``.
            * ``memory_trf`` is ``(memory, memory_gradients)``.
        """
        memory = self.trf_model.initialize_memory()
        variables = self.trf_model.init(rng_key, inputs.reshape(1, -1), memory)
        state_trf, params_trf = self._split_params(variables)
        memory_gradients = self.trf_model.initialize_memory_gradients(memory, params_trf)
        memory_trf = (memory, memory_gradients)
        # NOTE(review): rng_key is reused for the head init rather than split;
        # kept as-is so existing initialisations are reproducible.
        # The dummy input assumes the head's input width is d_model.
        variables = self.pred.init(rng_key, jnp.zeros(self.d_model))
        state_out, params_out = self._split_params(variables)
        state = {'trf': state_trf, 'out': state_out}
        params = {'trf': params_trf, 'out': params_out}
        return params, state, memory_trf

    @partial(jax.jit, static_argnums=(0,))
    def __call__(self, inputs, target, params, state, memory_trf):
        """Run one online step and return the loss, gradients and new memory.

        Args:
            inputs: a single observation. It is reshaped to ``(1, -1)``.
            target: regression target for the squared-error loss.
            params: ``{'trf': ..., 'out': ...}`` as returned by ``init_params_state``.
            state: ``{'trf': ..., 'out': ...}`` non-param collections.
            memory_trf: ``(memory, memory_gradients)`` carried from the previous step.

        Returns:
            ``(loss, grads, (memory, memory_gradients))``. ``grads`` mirrors
            the structure of ``params``.
        """
        inputs = inputs.reshape(1, -1)
        state_trf, state_out = state['trf'], state['out']
        params_trf, params_out = params['trf'], params['out']
        memory, memory_gradients = memory_trf

        def loss_fn(params_out, state_out, trf_out, target):
            # apply() is called without `mutable`, so non-param state is never updated.
            pred = self.pred.apply({'params': params_out, **state_out}, trf_out)
            l = ((target - pred) ** 2).sum()  # Replace with your loss here.
            return l

        # As above, transformer state is read but never updated.
        (u_iplus1, delutplus1_deltheta), (memory, memory_gradients) = self.trf_model.apply(
            {'params': params_trf, **state_trf}, inputs, memory, memory_gradients,
            method=self.trf_model.step_and_grad)
        # Since output is TXc and TXs, keep entry 0 of the leading axis.
        memory = tree_index(memory, 0)
        memory_gradients = tree_index(memory_gradients, 0)
        l, (out_grads, dell_delu) = jax.value_and_grad(loss_fn, argnums=(0, 2))(
            params_out, state_out, u_iplus1, target)
        # Chain rule dl/dtheta = dl/du . du/dtheta. This contracts the last two
        # axes of dl/du with the first two axes of each Jacobian leaf.
        # Assumes u is 2-D and each leaf starts with u's shape — TODO confirm.
        trf_grads = jax.tree_util.tree_map(
            lambda x: jnp.tensordot(dell_delu, x, axes=2), delutplus1_deltheta)
        grads = {'trf': trf_grads, 'out': out_grads}
        return l, grads, (memory, memory_gradients)


class VanillaRNNPredictor(nn.Module):
    """Scalar regression head on top of a ``TruncatedVanillaRNN``.

    The RNN output is passed in full (no indexing) through a
    Dense -> ReLU -> Dense(1) MLP.

    Attributes:
        d_model: hidden width forwarded to the RNN.
        output_hidden_size: hidden width of the prediction MLP.
        truncation: truncation length forwarded to the RNN.
    """
    d_model: int
    output_hidden_size: int
    truncation: int

    @nn.compact
    def __call__(self, inputs):
        """Run the RNN on ``inputs`` and map its output to a single unit."""
        rnn = TruncatedVanillaRNN(d_model=self.d_model, truncation=self.truncation)
        features = rnn(inputs)
        head = nn.Sequential([
            nn.Dense(self.output_hidden_size),
            jax.nn.relu,
            nn.Dense(1),
        ])
        return head(features)